{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<center>\n",
    "    <p style=\"text-align:center\">\n",
    "        <img alt=\"phoenix logo\" src=\"https://storage.googleapis.com/arize-phoenix-assets/assets/phoenix-logo-light.svg\" width=\"200\"/>\n",
    "        <br>\n",
    "        <a href=\"https://arize.com/docs/phoenix/\">Docs</a>\n",
    "        |\n",
    "        <a href=\"https://github.com/Arize-ai/phoenix\">GitHub</a>\n",
    "        |\n",
    "        <a href=\"https://arize-ai.slack.com/join/shared_invite/zt-2w57bhem8-hq24MB6u7yE_ZF_ilOYSBw#/shared-invite/email\">Community</a>\n",
    "    </p>\n",
    "</center>\n",
    "\n",
    "# Google GenAI SDK - Building an Evaluator-Optimizer Agent"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Install Dependencies"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install -q google-genai openinference-instrumentation-google-genai arize-phoenix-otel"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Connect to Arize Phoenix"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "from getpass import getpass\n",
    "\n",
    "from phoenix.otel import register\n",
    "\n",
    "if \"PHOENIX_API_KEY\" not in os.environ:\n",
    "    os.environ[\"PHOENIX_API_KEY\"] = getpass(\"🔑 Enter your Phoenix API key: \")\n",
    "\n",
    "if \"PHOENIX_COLLECTOR_ENDPOINT\" not in os.environ:\n",
    "    os.environ[\"PHOENIX_COLLECTOR_ENDPOINT\"] = getpass(\"🔑 Enter your Phoenix Collector Endpoint\")\n",
    "\n",
    "tracer_provider = register(\n",
    "    auto_instrument=True, project_name=\"google-genai-evaluator-optimizer-agent\"\n",
    ")\n",
    "tracer = tracer_provider.get_tracer(__name__)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Authenticate with Google Vertex AI"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from google import genai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!gcloud auth login"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Create a client using the Vertex AI API, you could also use the Google GenAI API instead here\n",
    "client = genai.Client(vertexai=True, project=\"<ADD YOUR GCP PROJECT ID>\", location=\"us-central1\")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Evaluator Optimizer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- 1. Instantiate the model(s) ---\n",
    "# The global `genai` configuration (or the `client` created in a previous\n",
    "# notebook cell) is already set up, so we can request models right away.\n",
    "story_model = \"gemini-2.0-flash-001\"\n",
    "critic_model = \"gemini-2.0-flash-001\"\n",
    "\n",
    "# --- Configuration for iterative improvement ---\n",
    "max_iterations = 5  # Maximum number of critique/revision cycles\n",
    "min_quality_threshold = 8  # Quality threshold on a scale of 1-10\n",
    "\n",
    "\n",
    "# --- 2. Generate the initial short story ---\n",
    "@tracer.agent()\n",
    "def generate_story(story_model, client):\n",
    "    initial_story = create_initial_story(story_model, client)\n",
    "    return iteratively_improve_story(initial_story, story_model, critic_model, client)\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def create_initial_story(story_model, client):\n",
    "    story_prompt = (\n",
    "        \"You are a creative short story writer.\\n\"\n",
    "        \"Write a brief, engaging story (3–4 paragraphs) about an unexpected \"\n",
    "        \"adventure.\\n\"\n",
    "        \"Be imaginative but concise.\\n\"\n",
    "        \"User Input: \"\n",
    "    )\n",
    "    story_response = client.models.generate_content(\n",
    "        model=story_model,\n",
    "        contents=story_prompt + input(\"Please enter a prompt or seed for your story: \"),\n",
    "    )\n",
    "    current_story = story_response.text.strip()\n",
    "\n",
    "    print(\"=== Initial Story ===\\n\")\n",
    "    print(current_story)\n",
    "\n",
    "    return current_story\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def iteratively_improve_story(current_story, story_model, critic_model, client):\n",
    "    iteration = 0\n",
    "    story_quality = 0  # Initial quality score\n",
    "\n",
    "    while iteration < max_iterations and story_quality < min_quality_threshold:\n",
    "        iteration += 1\n",
    "        print(f\"\\n--- Iteration {iteration} ---\")\n",
    "\n",
    "        # Generate critique and extract quality score\n",
    "        critique_text, story_quality = critique_story(\n",
    "            current_story, critic_model, client, iteration\n",
    "        )\n",
    "\n",
    "        # Check if revision is needed\n",
    "        if story_quality >= min_quality_threshold:\n",
    "            print(\n",
    "                f\"\\nStory quality ({story_quality}/10) meets or exceeds threshold ({min_quality_threshold}/10).\"\n",
    "            )\n",
    "            print(\"No further revisions needed.\")\n",
    "            break\n",
    "\n",
    "        # Improve the story based on the critique\n",
    "        current_story = revise_story(current_story, critique_text, story_model, client, iteration)\n",
    "\n",
    "    if iteration >= max_iterations and story_quality < min_quality_threshold:\n",
    "        print(\n",
    "            f\"\\n--- Maximum iterations ({max_iterations}) reached without meeting quality threshold ---\"\n",
    "        )\n",
    "\n",
    "    print(\"\\n--- Final Story ---\")\n",
    "    print(current_story)\n",
    "    return current_story\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def critique_story(current_story, critic_model, client, iteration):\n",
    "    print(\"\\nGenerating critique and quality assessment...\")\n",
    "    critic_prompt = (\n",
    "        \"You are a literary critic.\\n\"\n",
    "        \"First, rate the story on a scale of 1-10 (where 10 is excellent).\\n\"\n",
    "        \"Then analyze the provided story for its strengths and weaknesses.\\n\"\n",
    "        \"Comment on plot, characters, and overall impact.\\n\"\n",
    "        \"If the story needs improvement, provide 2–3 specific suggestions.\\n\"\n",
    "        \"Format your response as follows:\\n\"\n",
    "        \"Quality Score: [1-10]\\n\"\n",
    "        \"Critique: [Your detailed critique]\\n\"\n",
    "        \"Suggestions: [Your specific suggestions if score is below 8]\\n\\n\"\n",
    "        f\"Story:\\n{current_story}\"\n",
    "    )\n",
    "    critic_response = client.models.generate_content(model=critic_model, contents=critic_prompt)\n",
    "    critique_text = critic_response.text.strip()\n",
    "\n",
    "    # Display the critique\n",
    "    print(f\"\\n=== Critique {iteration} ===\\n\")\n",
    "    print(critique_text)\n",
    "\n",
    "    # Extract quality score from critique\n",
    "    try:\n",
    "        # Try to extract the quality score from the critique\n",
    "        quality_line = [line for line in critique_text.split(\"\\n\") if \"Quality Score:\" in line]\n",
    "        if quality_line:\n",
    "            story_quality = int(quality_line[0].split(\":\")[1].strip().split()[0])\n",
    "        else:\n",
    "            story_quality = 0  # Default if not found\n",
    "    except (ValueError, IndexError):\n",
    "        story_quality = 0  # Default if parsing fails\n",
    "\n",
    "    print(f\"\\nQuality Score: {story_quality}/10\")\n",
    "\n",
    "    return critique_text, story_quality\n",
    "\n",
    "\n",
    "@tracer.chain()\n",
    "def revise_story(current_story, critique_text, story_model, client, iteration):\n",
    "    print(\"\\nGenerating revision...\")\n",
    "    revision_prompt = (\n",
    "        \"You are a creative short story writer.\\n\"\n",
    "        \"Revise the following story based *only* on the critique provided.\\n\"\n",
    "        \"Focus on addressing the specific suggestions mentioned in the critique.\\n\"\n",
    "        \"Do not introduce significant new elements not prompted by the critique.\\n\"\n",
    "        \"Make substantial improvements to raise the quality score.\\n\\n\"\n",
    "        f\"Critique:\\n{critique_text}\\n\\n\"\n",
    "        f\"Original Story:\\n{current_story}\\n\\n\"\n",
    "        \"Please produce an improved version of the story, addressing the suggestions.\\n\"\n",
    "        \"Output *only* the revised story.\"\n",
    "    )\n",
    "    revision_response = client.models.generate_content(model=story_model, contents=revision_prompt)\n",
    "    revised_story = revision_response.text.strip()\n",
    "\n",
    "    # Display the improved story for this iteration\n",
    "    print(f\"\\n=== Improved Story (Iteration {iteration}) ===\\n\")\n",
    "    print(revised_story)\n",
    "\n",
    "    return revised_story\n",
    "\n",
    "\n",
    "generate_story(story_model, client)"
   ]
  }
 ],
 "metadata": {
  "language_info": {
   "name": "python"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
